import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.decomposition import PCA
import trunc_functions as trun 


def gaussian_signal(Depth, mu, sig):
    """Return an unnormalised Gaussian curve sampled at x = 0, 1, ..., Depth-1.

    The value is ``exp(-((x - mu) / sig)**2 / 2)``. It equals 1 where
    ``x == mu``, and ``sig`` is the standard deviation in sample units.
    """
    positions = np.arange(Depth)
    z = (positions - mu) / sig
    half_square = np.power(z, 2.) / 2
    return np.exp(-half_square)

def signal(D, Start, End, Fract, NoiseSigma, Sigmas=5):
    """Return a length-``D`` spectrum holding one Gaussian peak plus optional noise.

    The peak is centred at ``(Start + End) / 2``. Its standard deviation is
    chosen so that the interval ``[Start, End]`` spans +-``Sigmas`` sigma
    around the centre.

    Parameters
    ----------
    D : int
        Number of channels in the spectrum.
    Start, End : float
        Channel positions bounding the peak (the +-``Sigmas`` sigma range).
    Fract : float
        Peak amplitude, i.e. the height of the Gaussian at its centre.
    NoiseSigma : float
        Standard deviation of additive Gaussian noise. No noise is added
        when it is <= 0.
    Sigmas : float, optional
        Number of standard deviations on each side of the centre that
        ``[Start, End]`` covers. Defaults to 5.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(D,)``.
    """
    sig = (End - Start) / 2 / Sigmas
    mu = (Start + End) / 2

    # Same Gaussian profile as gaussian_signal(D, mu, sig), scaled by Fract.
    # The result must be a float array: adding float noise in place to an
    # integer array (e.g. np.arange(D)) raises a casting error.
    x = np.arange(D)
    spec = Fract * np.exp(-np.power((x - mu) / sig, 2.) / 2)

    if NoiseSigma > 0:
        spec += np.random.normal(0, NoiseSigma, D)

    return spec

def compound_layer(Height,Width):
    """Return a (Height, Width) map whose rows all hold a raised-cosine profile.

    Along the columns the value is ``(1 - cos(2*pi*x/Width)) / 2``. It is 0
    at column 0, reaches 1 at x = Width/2 and falls back towards 0 at the
    right edge. Every row is identical.
    """
    columns = np.arange(Width)
    # Sinusoidal distribution from left to right.
    row = (1 - np.cos(2 * np.pi * columns / Width)) / 2
    return np.repeat(row[np.newaxis, :], Height, axis=0)

def layers_fragment(Height,Width):
    """Build one (Height, Width, 3) tile holding three overlapping compound layers.

    The width is split into thirds of ``HW = Width // 3`` columns. Each
    channel holds a ``compound_layer`` that is ``2*HW`` columns wide:

    * channel 1 (B) covers the left two thirds;
    * channel 2 (C) covers the right two thirds;
    * channel 0 (A) is split. Its right half fills the first third and its
      left half fills the last third, so placing tiles side by side joins
      the halves into a full A profile across each tile boundary.

    Raises
    ------
    ValueError
        If ``Width`` is not a multiple of 3. The layout needs three equal
        parts; otherwise the open-ended slices below would not match the
        ``2*HW``-wide layer.
    """
    if Width % 3:
        raise ValueError(
            f"layers_fragment: Width must be a multiple of 3, got {Width}")

    fragment = np.zeros((Height, Width, 3))
    HW = Width // 3
    # All four assignments use the same profile, so compute it only once.
    layer = compound_layer(Height, 2 * HW)
    fragment[:, :HW, 0] = layer[:, HW:]       # right half of A
    fragment[:, :2 * HW, 1] = layer           # B compound
    fragment[:, HW:, 2] = layer               # C compound
    fragment[:, 2 * HW:, 0] = layer[:, :HW]   # left half of A

    return fragment


def make_SI_3features(im,Depth,SignalSigma,NoiseSigma):
    """Synthesize a spectral image from a 3-channel feature map.

    Each pixel's spectrum is a weighted sum of three Gaussian peaks of width
    ``SignalSigma``, centred at ``Depth/4``, ``2*Depth/4`` and ``3*Depth/4``.
    The weights are the pixel's values ``im[y, x, 0]``, ``im[y, x, 1]`` and
    ``im[y, x, 2]``. When ``NoiseSigma > 0``, Gaussian noise with that
    standard deviation is added independently to every spectral sample.

    Parameters
    ----------
    im : numpy.ndarray
        Array of shape (Height, Width, C) with C >= 3; only channels 0-2
        are used.
    Depth : int
        Number of spectral channels per pixel.
    SignalSigma : float
        Standard deviation of each Gaussian peak, in channels.
    NoiseSigma : float
        Standard deviation of the additive noise. No noise when <= 0.

    Returns
    -------
    numpy.ndarray
        Array of shape (Height, Width, Depth).
    """
    Height = im.shape[0]
    Width = im.shape[1]

    # The basis peaks do not depend on the pixel: compute each one once
    # instead of once per pixel.
    peak_A = gaussian_signal(Depth, Depth / 4, SignalSigma)
    peak_B = gaussian_signal(Depth, 2 * Depth / 4, SignalSigma)
    peak_C = gaussian_signal(Depth, 3 * Depth / 4, SignalSigma)

    # (Height, Width, 1) weights broadcast against (Depth,) peaks. The terms
    # are summed in the same order as the per-pixel formula A + B + C + noise.
    imSI = im[:, :, 0:1] * peak_A    # add 1st feature
    imSI += im[:, :, 1:2] * peak_B   # add 2nd
    imSI += im[:, :, 2:3] * peak_C   # add 3rd

    if NoiseSigma > 0:  # add Gaussian noise, drawn for all pixels at once
        imSI += np.random.normal(0, NoiseSigma, (Height, Width, Depth))

    return imSI


def scatterplot(scores,First,Second):
    """Show a scatter plot of two columns of ``scores`` with equal axis scaling.

    ``First`` and ``Second`` are 1-based component numbers: component ``k``
    is column ``k - 1`` of ``scores``.
    """
    x_col = First - 1   # 1-based component number -> 0-based column index
    y_col = Second - 1
    plt.scatter(scores[:, x_col], scores[:, y_col], s=1)
    plt.gca().set_aspect('equal')
    plt.show()



# ---------------------------------------------------------------------------
# Demo: build a synthetic 3-feature spectral image, run PCA on it and
# estimate how many principal components carry signal.
# NOTE(review): noise comes from the unseeded global NumPy RNG, so results
# vary from run to run; seed it (np.random.seed(...)) if reproducibility
# is needed.
# ---------------------------------------------------------------------------

Width = 90
Height = 100
N_tiles = 3
Tile_width = Width // N_tiles
maps = np.zeros((Height, Width, 3))
for i in range(N_tiles):
    # Fragment size is derived from Height/Width rather than hard-coded, so
    # changing the image size keeps every tile consistent with ``maps``.
    maps[:, i * Tile_width:(i + 1) * Tile_width, :] = layers_fragment(Height, Tile_width)

plt.imshow(Image.fromarray((255 * maps).astype('uint8')))
plt.show()

Depth = 1000
SignalSigma = 50
NoiseSigma = 0.5
SI = make_SI_3features(maps, Depth, SignalSigma, NoiseSigma)
spec = SI[50, 46, :]  # spectrum of one pixel; already 1-D with length Depth
plt.plot(np.arange(Depth), spec)
plt.show()

# Unfold to a (pixels, channels) matrix, one row per pixel. Use reshape
# rather than assigning to ``.shape``; copy so Matrix stays independent of SI.
Matrix = SI.reshape(Height * Width, Depth).copy()

Extracted_components = 10
pca = PCA(n_components=Extracted_components)
pca.fit(Matrix)
scores = pca.transform(Matrix)
print(scores.shape)

scatterplot(scores, 1, 2)

aniso_plot = trun.anisotropy_plot(scores)
# NOTE(review): assumes anisotropy_plot returns Extracted_components-1 values
# (one per adjacent component pair?) -- TODO confirm against trunc_functions.
plt.plot(np.arange(Extracted_components - 1) + 1, aniso_plot)
plt.show()

cut, _, _, _ = trun.auto_cut(aniso_plot)
print('truncated at', cut, 'components,', 'components', cut + 1, 'and higher are noise')
